library(conText)
library(quanteda)
library(text2vec) # for implementation of GloVe algorithm
library(stringr) # to handle text strings
library(umap)
library(ggplot2)
library(dplyr)
library(tidytext)
library(readr)
# corpus
# load() restores objects under the names they were saved with. Confirm that
# the `MPtweets` object used below really came from this file, so that a
# missing or renamed object cannot be silently replaced by a stale `MPtweets`
# left over in the workspace.
loaded_objects <- load("data/analysis/MPtweetsv2.Rdata")
if (!"MPtweets" %in% loaded_objects) {
  stop(
    "data/analysis/MPtweetsv2.Rdata does not contain an object named 'MPtweets'",
    call. = FALSE
  )
}

# ================================ choice parameters ================================
COUNT_MIN <- 10    # term_count_min used when pruning the vocabulary
WINDOW_SIZE <- 6   # skip-gram window (and weight-vector length) for create_tcm()
DIM <- 300         # rank handed to GlobalVectors$new()
ITERS <- 100       # n_iter handed to glove$fit_transform()

# Put the tweets in random order; the fixed seed makes the permutation
# repeatable from run to run.
set.seed(seed = 42L)
text <- sample(x = MPtweets$tweet)

# ================================ preprocess ================================

cr_text <- corpus(text)

# Tokenize, dropping punctuation, symbols, numbers and separators.
# The flags are spelled TRUE rather than T: T is an ordinary variable that can
# be reassigned, whereas TRUE is a reserved constant.
cr_tokens <- tokens(
  cr_text,
  remove_punct = TRUE,
  remove_symbols = TRUE,
  remove_numbers = TRUE,
  remove_separators = TRUE
)

# Convert the quanteda tokens object to a plain list, the form handed to
# itoken() below.
# NOTE: this variable shares its name with quanteda::tokens(). Later calls to
# tokens() keep working because R skips non-function objects when it resolves
# a function call, but the name is kept only because the vocabulary step reads
# it.
tokens <- as.list(cr_tokens)

# ================================ create vocab ================================
# (alternative tokenizer, currently disabled: tokens <- space_tokenizer(text))
it <- itoken(iterable = tokens, progressbar = FALSE)
vocab <- create_vocabulary(it = it)
# drop every term whose count falls below the COUNT_MIN threshold
vocab_pruned <- prune_vocabulary(
  vocabulary = vocab,
  term_count_min = COUNT_MIN
)

# ================================ create term co-occurrence matrix ================================
vectorizer <- vocab_vectorizer(vocabulary = vocab_pruned)
tcm <- create_tcm(
  it = it,
  vectorizer = vectorizer,
  skip_grams_window = WINDOW_SIZE,
  skip_grams_window_context = "symmetric",
  # constant weights: every position inside the window gets weight 1
  weights = rep(1, times = WINDOW_SIZE)
)

# ================================ set model parameters ================================
glove <- GlobalVectors$new(
  rank = DIM,
  x_max = 100,
  learning_rate = 0.05
)

# ================================ fit model ================================
# Fit on the co-occurrence matrix for at most ITERS iterations, using the
# thread count reported by RcppParallel.
word_vectors_main <- glove$fit_transform(
  tcm,
  n_iter = ITERS,
  convergence_tol = 0.001,
  n_threads = RcppParallel::defaultNumThreads()
)


# ================================ get output ================================
# The fitted model keeps its context vectors in `components`; they are
# transposed and added to the main vectors to give the final word embedding.
word_vectors_context <- glove$components
glove_embedding <- word_vectors_main +
  t(word_vectors_context)

# ================================ save ================================
# Make sure the output directory exists before writing, so the fitted
# embedding is not lost to a "cannot open file" error after training. This is
# a silent no-op when the directory is already present.
dir.create("data/wordembeddings", recursive = TRUE, showWarnings = FALSE)
saveRDS(glove_embedding, file = "data/wordembeddings/local_glove.rds")